/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "InitFuncDef.h"
#include "ascir/Dialect/Asc/IR/Asc.h"
#include "ascir/Dialect/Asc/Utils/Attributes.h"
#include "ascir/Dialect/Asc/Utils/Utils.h"
#include "ascir/Dialect/EmitAsc/IR/EmitAsc.h"
#include "ascir/Dialect/EmitAsc/Utils/Attributes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>  // automatic casts between containers and python types

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

namespace py = pybind11;
using namespace mlir;

namespace {

constexpr unsigned INDEX_64 = 64;

OpPrintingFlags getOpPrintingFlags()
{
    auto printingFlags = OpPrintingFlags();
    printingFlags.enableDebugInfo();
    return printingFlags;
}

std::optional<SmallVector<emitasc::KernelArgument>> getKernelArgAttrs(ModuleOp op)
{
    func::FuncOp kernelFunc;
    op.walk([&](func::FuncOp fn) -> WalkResult {
        if (fn->hasAttrOfType<UnitAttr>(ascendc::attr::global)) {
            kernelFunc = fn;
            return WalkResult::interrupt();
        }
        return WalkResult::advance();
    });
    if (!kernelFunc) {
        return std::nullopt;
    }
    SmallVector<emitasc::KernelArgument> kernelArgs;
    unsigned numArgs = kernelFunc.getNumArguments();
    kernelArgs.reserve(numArgs);
    for (unsigned i = 0; i < numArgs; i++) {
        auto attr = kernelFunc.getArgAttrOfType<emitasc::KernelArgumentAttr>(i, emitasc::attr::kernelArg);
        kernelArgs.push_back(attr ? attr.getValue() : emitasc::KernelArgument::Explicit);
    }
    return kernelArgs;
}

}  // namespace

namespace pybind11 {
namespace asc {
void pyasc_bind_enums(py::module &m)
{
    using ret = py::return_value_policy;
    using namespace pybind11::literals;

    m.attr("dynshape") = py::int_(ShapedType::kDynamic);

    py::enum_<ascendc::AddressSpace>(m, "AddressSpace", py::module_local())
        .value("ca", ascendc::AddressSpace::ca)
        .value("cb", ascendc::AddressSpace::cb)
        .value("cbuf", ascendc::AddressSpace::cbuf)
        .value("cc", ascendc::AddressSpace::cc)
        .value("Default", ascendc::AddressSpace::Default)
        .value("fbuf", ascendc::AddressSpace::fbuf)
        .value("gm", ascendc::AddressSpace::gm)
        .value("ubuf", ascendc::AddressSpace::ubuf);

    py::enum_<ascendc::AippInputFormat>(m, "AippInputFormat", py::module_local())
        .value("YUV420SP_U8", ascendc::AippInputFormat::YUV420SP_U8)
        .value("XRGB8888_U8", ascendc::AippInputFormat::XRGB8888_U8)
        .value("RGB888_U8", ascendc::AippInputFormat::RGB888_U8)
        .value("YUV400_U8", ascendc::AippInputFormat::YUV400_U8)
        .def_static("symbolize", [](uint8_t input_format) -> ascendc::AippInputFormat {
            return static_cast<ascendc::AippInputFormat>(input_format);
        });

    py::enum_<ascendc::CacheLine>(m, "CacheLine", py::module_local())
        .value("SINGLE_CACHE_LINE", ascendc::CacheLine::SINGLE_CACHE_LINE)
        .value("ENTIRE_DATA_CACHE", ascendc::CacheLine::ENTIRE_DATA_CACHE)
        .def_static("symbolize", [](uint8_t v) -> ascendc::CacheLine {
            return static_cast<ascendc::CacheLine>(v);
        });

    py::enum_<arith::CmpFPredicate>(m, "CmpFPredicate", py::module_local())
        .value("OEQ", arith::CmpFPredicate::OEQ)
        .value("ONE", arith::CmpFPredicate::ONE)
        .value("OGE", arith::CmpFPredicate::OGE)
        .value("OGT", arith::CmpFPredicate::OGT)
        .value("OLE", arith::CmpFPredicate::OLE)
        .value("OLT", arith::CmpFPredicate::OLT);

    py::enum_<arith::CmpIPredicate>(m, "CmpIPredicate", py::module_local())
        .value("eq", arith::CmpIPredicate::eq)
        .value("ne", arith::CmpIPredicate::ne)
        .value("sge", arith::CmpIPredicate::sge)
        .value("sgt", arith::CmpIPredicate::sgt)
        .value("sle", arith::CmpIPredicate::sle)
        .value("slt", arith::CmpIPredicate::slt);

    py::enum_<ascendc::DcciDst>(m, "DcciDst", py::module_local())
        .value("CACHELINE_ALL", ascendc::DcciDst::CACHELINE_ALL)
        .value("CACHELINE_UB", ascendc::DcciDst::CACHELINE_UB)
        .value("CACHELINE_OUT", ascendc::DcciDst::CACHELINE_OUT)
        .value("CACHELINE_ATOMIC", ascendc::DcciDst::CACHELINE_ATOMIC)
        .def_static("symbolize", [](uint8_t v) -> ascendc::DcciDst {
            return static_cast<ascendc::DcciDst>(v);
        });

    py::enum_<ascendc::MaskMode>(m, "MaskMode", py::module_local())
        .value("NORMAL", ascendc::MaskMode::NORMAL)
        .value("COUNTER", ascendc::MaskMode::COUNTER)
        .def_static("symbolize", [](uint8_t v) -> ascendc::MaskMode { return static_cast<ascendc::MaskMode>(v);});
    
    py::enum_<ascendc::ReduceOrder>(m, "ReduceOrder", py::module_local())
        .value("ORDER_VALUE_INDEX", ascendc::ReduceOrder::ORDER_VALUE_INDEX)
        .value("ORDER_INDEX_VALUE", ascendc::ReduceOrder::ORDER_INDEX_VALUE)
        .value("ORDER_ONLY_VALUE", ascendc::ReduceOrder::ORDER_ONLY_VALUE)
        .value("ORDER_ONLY_INDEX", ascendc::ReduceOrder::ORDER_ONLY_INDEX)
        .def_static("symbolize", [](uint8_t v) -> ascendc::ReduceOrder {
            return static_cast<ascendc::ReduceOrder>(v);
        });
    
    py::enum_<ascendc::RoundMode>(m, "RoundMode", py::module_local())
        .value("CAST_NONE", ascendc::RoundMode::CAST_NONE)
        .value("CAST_RINT", ascendc::RoundMode::CAST_RINT)
        .value("CAST_FLOOR", ascendc::RoundMode::CAST_FLOOR)
        .value("CAST_CEIL", ascendc::RoundMode::CAST_CEIL)
        .value("CAST_ROUND", ascendc::RoundMode::CAST_ROUND)
        .value("CAST_TRUNC", ascendc::RoundMode::CAST_TRUNC)
        .value("CAST_ODD", ascendc::RoundMode::CAST_ODD)
        .def_static("symbolize", [](uint8_t v) -> ascendc::RoundMode {
            return static_cast<ascendc::RoundMode>(v);
        });
    
    py::enum_<ascendc::TPosition>(m, "TPosition", py::module_local())
        .def_static("symbolize", [](uint8_t pos) -> ascendc::TPosition { return static_cast<ascendc::TPosition>(pos); });
}

void pyasc_bind_context_and_dialect(py::module &m)
{
    py::class_<MLIRContext>(m, "Context", py::module_local())
        .def(py::init<>())
        .def("disable_multithreading", [](MLIRContext &self) { self.disableMultithreading(); });

    m.def("load_dialects", [](MLIRContext &context) {
        DialectRegistry registry;
        registry.insert<
            //
            arith::ArithDialect,
            ascendc::AscendCDialect,
            emitasc::EmitAscDialect,
            emitc::EmitCDialect,
            func::FuncDialect,
            memref::MemRefDialect,
            scf::SCFDialect,
            vector::VectorDialect
            //
            >();
        ascendc::registerExternalModels(registry);
        ascendc::registerInlinerInterfaces(registry);
        emitasc::registerExternalModels(registry);
        func::registerAllExtensions(registry);
        context.appendDialectRegistry(registry);
        context.loadAllAvailableDialects();
    });
}

void pyasc_bind_type(py::module &m)
{
    using namespace pybind11::literals;
    py::class_<Type>(m, "Type", py::module_local())
        .def("is_integer", [](Type &self) -> bool { return self.isInteger(); })
        .def("is_index", &Type::isIndex)
        .def("__eq__",
            [](Type &self, py::object &other) {
                Type *other_ty = py::cast<Type *>(other);
                return (other_ty != nullptr) && (*other_ty == self);
            })
        .def("__ne__",
            [](Type &self, py::object &other) {
                Type *other_ty = py::cast<Type *>(other);
                return (other_ty == nullptr) || (*other_ty != self);
            })
        .def("get_py_name",
            [](Type &self) -> std::optional<std::string> {
                if (isa<IntegerType>(self)) {
                    std::string name = self.isUnsignedInteger() ? "uint" : "int";
                    name += std::to_string(self.getIntOrFloatBitWidth());
                    return name;
                }
                if (isa<FloatType>(self)) {
                    std::string name = "float";
                    name += std::to_string(self.getIntOrFloatBitWidth());
                    return name;
                }
                if (isa<NoneType>(self))
                    return "void";
                return std::nullopt;
            })
        .def("__str__", [](Type &self) {
            std::string str;
            llvm::raw_string_ostream os(str);
            self.print(os);
            os.flush();
            return os.str();
        });
}

void pyasc_bind_memref(py::module &m)
{
    using namespace pybind11::literals;
    m.def("get_element_type", [](const Type &shapedType) -> Type {
        auto type = llvm::dyn_cast_if_present<ShapedType>(shapedType);
        if (!type)
            throw std::runtime_error("get_element_type(): must be shaped type");
        return type.getElementType();
    });

    m.def("get_shape", [](const Type &shapedType) -> std::vector<int64_t> {
        auto type = llvm::dyn_cast_if_present<ShapedType>(shapedType);
        if (!type)
            throw std::runtime_error("get_shape(): must be shaped type");
        return type.getShape().vec();
    });

    m.def("get_vector_type",
        [](Type &elementType, std::vector<int64_t> &shape) -> Type { return VectorType::get(shape, elementType); });

    m.def(
        "get_memref_type",
        [](Type &elementType,
            const std::variant<std::vector<int64_t>, int64_t> &shape,
            std::optional<int64_t>
                addressSpace) -> Type {
            Attribute memorySpace;
            if (auto as = addressSpace.value_or(0)) {
                memorySpace = IntegerAttr::get(IntegerType::get(elementType.getContext(), INDEX_64), as);
            }
            SmallVector<int64_t> sh;
            if (std::holds_alternative<int64_t>(shape)) {
                sh.push_back(std::get<int64_t>(shape));
            } else {
                const auto &shapeVec = std::get<std::vector<int64_t>>(shape);
                sh.append(shapeVec.begin(), shapeVec.end());
            }
            return MemRefType::get(sh, elementType, AffineMap{}, memorySpace);
        },
        "element_type"_a,
        "shape"_a,
        "address_space"_a = py::none());
    m.def(
        "get_unranked_memref_type",
        [](Type &elementType, std::optional<int64_t> addressSpace) -> Type {
            Attribute memorySpace;
            if (auto as = addressSpace.value_or(0))
                memorySpace = IntegerAttr::get(IntegerType::get(elementType.getContext(), INDEX_64), as);
            return UnrankedMemRefType::get(elementType, memorySpace);
        },
        "element_type"_a,
        "address_space"_a = py::none());
}

void pyasc_bind_tensor_type(py::module &m)
{
    using namespace pybind11::literals;
    m.def("get_global_tensor_type", [](Type &elementType, std::vector<int64_t> &shape) -> Type {
        return ascendc::GlobalTensorType::get(shape, elementType);
    });

    m.def("get_global_tensor_type",
        [](Type &elementType) -> Type { return ascendc::GlobalTensorType::get(elementType); });

    m.def("get_local_tensor_type", [](Type &elementType, std::vector<int64_t> &shape) -> Type {
        return ascendc::LocalTensorType::get(shape, elementType);
    });

    m.def(
        "get_local_tensor_type", [](Type &elementType) -> Type { return ascendc::LocalTensorType::get(elementType); });

    m.def("get_opaque_type_name",
        [](Type &type) -> std::string { return cast<emitc::OpaqueType>(type).getValue().str(); });
}

void pyasc_bind_location(py::module &m)
{
    py::class_<Location>(m, "Location", py::module_local()).def("__str__", [](Location &self) {
        std::string str;
        llvm::raw_string_ostream os(str);
        self.print(os);
        return os.str();
    });
}

void pyasc_bind_value(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<Value>(m, "Value", py::module_local())
        .def("get_context", &Value::getContext, ret::reference)
        .def(
            "get_defining_op",
            [](Value &self) -> std::optional<Operation *> {
                auto *def = self.getDefiningOp();
                if (def)
                    return def;
                return std::nullopt;
            },
            ret::reference)
        .def("replace_all_uses_with", [](Value &self, Value &newValue) { self.replaceAllUsesWith(newValue); })
        .def("replace_uses_in_block", [](Value &self, Block *block, Value &newValue) {
            self.replaceUsesWithIf(newValue, [block](OpOperand &opnd) -> bool {
                auto *op = opnd.getOwner();
                Block *parentBlock = op->getBlock();
                while (parentBlock) {
                    if (parentBlock == block)
                        return true;
                    if (auto *parentOp = parentBlock->getParentOp())
                        parentBlock = parentOp->getBlock();
                    else
                        parentBlock = nullptr;
                }
                return false;
            });
        })
        .def("get_type", &Value::getType)
        .def("dump", &Value::dump)
        .def("id", [](Value &self) { return reinterpret_cast<uint64_t>(self.getImpl()); });
}

void pyasc_bind_region(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<OpResult, Value>(m, "OpResult", py::module_local());
    py::class_<BlockArgument, Value>(m, "BlockArgument", py::module_local());
    py::class_<Region>(m, "Region", py::module_local())
        .def("get_parent_region", &Region::getParentRegion, ret::reference)
        .def(
            "get_block",
            [](Region &self, unsigned index) -> Block& {
                if (index >= self.getBlocks().size())
                    throw std::runtime_error("block index is out of range");
                return *std::next(self.begin(), index);
            },
            ret::reference)
        .def("size", [](Region &self) { return self.getBlocks().size(); })
        .def("empty", &Region::empty)
        .def("id", [](Region &self) { return (uint64_t)&self; });
}

void pyasc_bind_blocks(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<Block>(m, "Block", py::module_local())
        .def(py::init())
        .def("dump", &Block::dump)
        .def("id", [](Block &self) { return (uint64_t)&self; })
        .def("has_terminator", &Block::mightHaveTerminator)
        .def("get_terminator", &Block::getTerminator, ret::reference)
        .def("add_argument",
            [](Block &self, Type &type) -> BlockArgument {
                return self.addArgument(type, UnknownLoc::get(type.getContext()));
            })
        .def("get_argument", &Block::getArgument)
        .def("get_arguments", [](Block &self) -> std::vector<BlockArgument> { return self.getArguments().vec(); })
        .def("merge_block_before",
            [](Block &self, Block &dst) {
                // See RewriterBase::mergeBlocks()
                if (self.getNumArguments() != 0)
                    throw std::runtime_error("Unable to merge block with arguments");
                dst.getOperations().splice(dst.begin(), self.getOperations());
                self.dropAllUses();
                if (self.getParent())
                    self.erase();
            })
        .def("clear", &Block::clear)
        .def("erase", &Block::erase);
}

void pyasc_bind_inline_block(py::module &m)
{
    m.def("inline_block_at_end", [](Block *src, Block *dst, const std::optional<std::vector<Value>> &args) {
        // See RewriterBase::inlineBlockBefore()
        ValueRange argValues({});
        if (args)
            argValues = *args;
        auto before = dst->end();
        if (argValues.size() != src->getNumArguments())
            throw std::runtime_error("incorrect # of argument replacement values");
        // Replace all of the successor arguments with the provided values
        for (auto [arg, newVal] : llvm::zip(src->getArguments(), argValues))
            arg.replaceAllUsesWith(newVal);
        dst->getOperations().splice(before, src->getOperations());
        if (!src->empty()) {
            throw std::runtime_error("expected 'src' to be empty");
        }
        if (src->getParent())
            src->erase();
    }, "src"_a, "dst"_a, "args"_a = py::none());
}

void pyasc_bind_attritube(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<Attribute>(m, "Attribute", py::module_local())
        .def("dump", &Attribute::dump)
        .def("id", [](Attribute &self) { return reinterpret_cast<uint64_t>(self.getAsOpaquePointer()); });

    py::class_<ArrayAttr, Attribute>(m, "ArrayAttr", py::module_local());

    m.def("get_type_attr", [](const Type &type) -> Attribute { return TypeAttr::get(type); });
}

void pyasc_bind_operation(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(m, "Operation", py::module_local())
        .def("get_name",
            [](Operation &self) {
                llvm::StringRef opName = self.getName().getStringRef();
                return opName.str();
            })
        .def("get_num_operands", &Operation::getNumOperands)
        .def("get_operand", &Operation::getOperand)
        .def("get_num_results", &Operation::getNumResults)
        .def("get_result", &Operation::getResult)
        .def("get_num_regions", &Operation::getNumRegions)
        .def("get_region", &Operation::getRegion, ret::reference)
        .def("get_block", &Operation::getBlock, ret::reference)
        .def("has_unit_attr",
            [](Operation &self, const std::string &name) -> bool { return self.hasAttrOfType<UnitAttr>(name); })
        .def("get_str_attr",
            [](Operation &self, const std::string &name) -> std::optional<std::string> {
                auto ret = self.getAttrOfType<StringAttr>(name);
                if (!ret)
                    return std::nullopt;
                return ret.getValue().str();
            })
        .def("get_bool_attr",
            [](Operation &self, const std::string &name) -> std::optional<bool> {
                auto ret = self.getAttrOfType<BoolAttr>(name);
                if (!ret)
                    return std::nullopt;
                return ret.getValue();
            })
        .def("get_integer_attr",
            [](Operation &self, const std::string &name) -> py::object {
                auto ret = self.getAttrOfType<IntegerAttr>(name);
                if (!ret)
                    return py::none();
                return py::int_(ret.getValue().getSExtValue());
            })
        .def("get_flat_symbol_ref_attr", [](Operation &self, const std::string &name) -> py::object {
            auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
            if (!ret)
                return py::none();
            return py::str(ret.getValue().str());
        });
}

void pyasc_bind_opstate(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<OpState>(m, "OpState", py::module_local())
        .def("get_context", &OpState::getContext, ret::reference)
        .def("set_attr", [](OpState &self, std::string &name, Attribute &attr) { self->setAttr(name, attr); })
        .def("get_num_results", [](OpState &self) -> unsigned { return self->getNumResults(); })
        .def("get_result",
            [](OpState &self, unsigned idx) -> Value {
                if (idx >= self->getNumResults())
                    throw pybind11::index_error("Op result index out of range");
                return self->getResult(idx);
            })
        .def(
            "get_region",
            [](OpState &self, unsigned idx) -> Region& {
                if (idx >= self->getNumRegions())
                    throw pybind11::index_error("Op region index out of range");
                return self->getRegion(idx);
            },
            ret::reference)
        .def("dump", [](OpState &self) { self->dump(); })
        .def("__str__",
            [](OpState &self) -> std::string {
                std::string str;
                llvm::raw_string_ostream os(str);
                auto printingFlags = getOpPrintingFlags();
                self->print(os, printingFlags);
                return str;
            })
        .def("append_operand", [](OpState &self, Value &val) { self->insertOperands(self->getNumOperands(), val); })
        .def("verify", [](OpState &self) -> bool { return succeeded(verify(self.getOperation())); })
        .def_property_readonly("op", &OpState::getOperation, ret::reference);
}

void pyasc_bind_moduleop(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<ModuleOp, OpState>(m, "ModuleOp", py::module_local())
        .def("dump", &ModuleOp::dump)
        .def(
            "get_body", [](ModuleOp &self) -> Block* { return self.getBody(); }, ret::reference)
        .def(
            "has_function",
            [](ModuleOp &self, const std::string &name, const std::optional<Type> &type) -> bool {
                auto *op = SymbolTable::lookupSymbolIn(self, name);
                if (auto funcOp = dyn_cast_if_present<func::FuncOp>(op))
                    return !type || funcOp.getFunctionType() == *type;
                return false;
            },
            "name"_a,
            "type"_a = py::none())
        .def("need_insert_sync", [](ModuleOp &self) {
            auto result = self.walk([](ascendc::LocalTensorAutoOp) {
                return WalkResult::interrupt();
            });
            return result.wasInterrupted();
        })
        .def("erase", [](ModuleOp &self) { self->erase(); });
}

void pyasc_bind_funcop(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<func::FuncOp, OpState>(m, "FuncOp", py::module_local())
        .def("get_arg",
            [](func::FuncOp &self, unsigned idx) -> BlockArgument {
                if (idx >= self.getNumArguments())
                    throw pybind11::index_error("Function argument index out of range");
                return self.getArgument(idx);
            })
        .def("get_num_args", &func::FuncOp::getNumArguments)
        .def(
            "add_entry_block", [](func::FuncOp &self) -> Block* { return self.addEntryBlock(); }, ret::reference)
        .def("set_type",
            [](func::FuncOp &self, const Type &funcType) {
                auto type = dyn_cast<FunctionType>(funcType);
                if (!type)
                    throw std::runtime_error("set_type(): must be FunctionType");
                self.setFunctionType(type);
            })
        .def("set_arg_names",
            [](func::FuncOp &self, const std::vector<std::string> &names) {
                if (names.size() != self.getNumArguments())
                    throw std::runtime_error("Number of names must be equal to number of arguments");
                for (unsigned i = 0; i < names.size(); i++) {
                    auto arg = self.getArgument(i);
                    auto name = StringAttr::get(self.getContext(), names[i]);
                    arg.setLoc(NameLoc::get(name, arg.getLoc()));
                }
            })
        .def(
            "get_body", [](func::FuncOp &self) -> Block& { return self.getFunctionBody().front(); }, ret::reference)
        .def("make_aicore",
            [](func::FuncOp &self) { self->setAttr(ascendc::attr::aicore, UnitAttr::get(self.getContext())); })
        .def("make_global", [](func::FuncOp &self) {
            self.setPublic();
            self->setAttr(ascendc::attr::global, UnitAttr::get(self.getContext()));
        });
}

void pyasc_bind_scfop(py::module &m)
{
    using ret = py::return_value_policy;
    py::class_<scf::ForOp, OpState>(m, "ForOp", py::module_local())
        .def("get_induction_var", &scf::ForOp::getInductionVar)
        .def(
            "get_body", [](scf::ForOp &self) -> Block* { return self.getBody(); }, ret::reference);
    py::class_<scf::IfOp, OpState>(m, "IfOp", py::module_local())
        .def("get_then_block", &scf::IfOp::thenBlock, ret::reference)
        .def("get_else_block", &scf::IfOp::elseBlock, ret::reference)
        .def("get_then_yield", &scf::IfOp::thenYield)
        .def("get_else_yield", &scf::IfOp::elseYield);
    py::class_<scf::YieldOp, OpState>(m, "YieldOp", py::module_local());
    py::class_<scf::WhileOp, OpState>(m, "WhileOp", py::module_local())
        .def("get_before", &scf::WhileOp::getBefore, ret::reference)
        .def("get_after", &scf::WhileOp::getAfter, ret::reference);
    py::class_<scf::ConditionOp, OpState>(m, "ConditionOp", py::module_local());
}

void pyasc_bind_kernel_argument(py::module &m) {
    py::enum_<emitasc::KernelArgument>(m, "KernelArgument", py::module_local())
        .value("Explicit", emitasc::KernelArgument::Explicit)
        .value("FftsAddr", emitasc::KernelArgument::FftsAddr);

    m.def("get_kernel_arg_attrs", [](ModuleOp &mod) -> py::object {
        auto kernelArgs = getKernelArgAttrs(mod);
        if (!kernelArgs) {
            return py::none();
        }
        py::list result;
        for (auto arg : kernelArgs.value()) {
            result.append(arg);
        }
        return result;
    });
}

void pyasc_init_ir(py::module &&m)
{
    pyasc_bind_enums(m);
    pyasc_bind_context_and_dialect(m);
    pyasc_bind_type(m);
    pyasc_bind_memref(m);
    pyasc_bind_tensor_type(m);
    pyasc_bind_location(m);
    pyasc_bind_value(m);
    pyasc_bind_region(m);
    pyasc_bind_blocks(m);
    pyasc_bind_inline_block(m);
    pyasc_bind_attritube(m);
    pyasc_bind_operation(m);
    pyasc_bind_opstate(m);
    pyasc_bind_moduleop(m);
    pyasc_bind_funcop(m);
    pyasc_bind_scfop(m);
    pyasc_bind_kernel_argument(m);
    py::class_<OpBuilder::InsertPoint>(m, "InsertPoint", py::module_local());

    pyasc_init_ir_builder(m);
}
}   // namespace asc
}   // namespace pybind11
